#!/usr/bin/env python3
"""
Phase 5b: Enhanced Bilateral Projections with WEO GDP and GE Clearing

Improvements over phase5_projections.py:
1. Uses WEO GDP projections through 2030 (not frozen at 2024)
2. For 2040/2050, uses 2030 GDP (last WEO year) with sensitivity notes
3. Adds bilateral GE clearing overlay:
   - Computes country-specific demographic yield pressure from S1 coefficients
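   - Computes the GDP-weighted global Δr* that clears the aggregate capital market
   - Applies a uniform GE level adjustment to every pair in each projection year
     after 2024: -0.161 × (-Δr*_world); bilateral yield differentials are
     reported but unaffected by a uniform world-rate shift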

Output:
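  - bilateral_projections_ge.csv: full projections with PE and GE columns
  - ge_bilateral_clearing_rates.csv: global Δr* by projection year
  - projection_summary_ge.csv: country-level net reallocation, PE vs GE (2050 vs 2024)
  - projection_decomposition_2050.csv: 2050 demographic / GDP / GE decomposition by reporter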
"""

import pandas as pd
import numpy as np
from pathlib import Path

BASE_DIR = Path("/mnt/c/demographics_capital_flows/gravity_bilateral")
FOLLOWUP_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup")
OUTPUT_DIR = BASE_DIR.joinpath("output", "tables")

# Model 2c log(Flow_ij) coefficients (gravity + demographics + KAOPEN
# interactions): log GDP product, demographic differentials
# dZ_k = Z_k(reporter) - Z_k(partner), their interactions with the partner's
# KAOPEN, and the partner KAOPEN level term.
COEFS_2C = dict(
    log_gdp_product=0.728381,
    dZ_1=3.961262,
    dZ_2=-0.523129,
    dZ_3=0.020360,
    dZ_1_x_kaopen_j=1.164013,
    dZ_2_x_kaopen_j=-0.139505,
    dZ_3_x_kaopen_j=0.004784,
    kaopen_j=0.217368,
)

# S1 coefficients mapping demographic factors Z_1..Z_3 to bond yields
# (23 OECD, from rate_channel_tests.csv)
S1_COEFS = dict(Z_1=16.319744, Z_2=-2.074562, Z_3=0.071788)

# Model 2f semi-elasticity on the bilateral fitted rate differential
# (fitted_rate_diff_ij coefficient)
BILATERAL_RATE_COEF = -0.161

# Model 3b multilateral rate-to-CA coefficient; plays the role of `delta`
# in the global clearing condition
DELTA_RATE_MULTILATERAL = 0.127

# Projection years; the first one is the comparison baseline
PROJ_YEARS = [2024, 2030, 2040, 2050]
# Cap on |Δr*| in percentage points, same as the multilateral GE model
MAX_DELTA_R = 2.0

# Key countries for bilateral analysis (order is preserved in all outputs)
KEY_COUNTRIES = (
    "JPN DEU ITA KOR ESP GBR FRA USA CAN AUS "
    "CHN IND BRA MEX IDN TUR THA VNM "
    "NGA EGY PHL BGD PAK KEN ETH "
    "SAU ARE SGP ZAF RUS POL MYS"
).split()


def load_data(panel_path=None, proj_years=None, last_weo_year=2030):
    """
    Load the followup panel and derive GDP, KAOPEN and country-name lookups.

    Parameters
    ----------
    panel_path : str or Path, optional
        CSV to read. Defaults to FOLLOWUP_DIR/data/processed/full_panel.csv.
    proj_years : iterable of int, optional
        Years to build GDP maps for. Defaults to PROJ_YEARS.
    last_weo_year : int, default 2030
        Last year with WEO GDP; projection years beyond it reuse that year's GDP.

    Returns
    -------
    fp : DataFrame
        The full panel as read from disk.
    gdp_by_year : dict
        {projection year: {iso3: ngdp_usd}} taken from year
        min(year, last_weo_year). Rows with missing iso3 or GDP are dropped,
        and the first row per iso3 is kept (matching the drop_duplicates
        used for Z data elsewhere in this script).
    kaopen_map : dict
        {iso3: latest non-missing kaopen with year <= 2024}. Empty if the
        panel has no 'kaopen' column; KAOPEN is not projected forward.
    name_map : dict
        {iso3: first non-missing country_name}. Empty if the column is absent.
    """
    if panel_path is None:
        panel_path = FOLLOWUP_DIR / "data" / "processed" / "full_panel.csv"
    if proj_years is None:
        proj_years = PROJ_YEARS

    fp = pd.read_csv(panel_path)

    # GDP map: actual WEO GDP through last_weo_year, frozen thereafter
    gdp_by_year = {}
    for year in proj_years:
        lookup_year = min(year, last_weo_year)
        yr_data = (fp.loc[fp['year'] == lookup_year, ['iso3', 'ngdp_usd']]
                   .dropna()
                   .drop_duplicates(subset='iso3', keep='first'))
        gdp_by_year[year] = dict(zip(yr_data['iso3'], yr_data['ngdp_usd']))

    # KAOPEN: latest observed value; GroupBy.last() skips NaN, so each
    # country gets its most recent non-missing observation.
    kaopen_map = {}
    if 'kaopen' in fp.columns:
        observed = fp[fp['year'] <= 2024].sort_values('year', kind='stable')
        kaopen_map = observed.groupby('iso3')['kaopen'].last().to_dict()

    # Country names
    name_map = {}
    if 'country_name' in fp.columns:
        name_map = (fp.dropna(subset=['country_name'])
                    .groupby('iso3')['country_name'].first().to_dict())

    return fp, gdp_by_year, kaopen_map, name_map


def compute_bilateral_effect(z_i, z_j, kaopen_j, gdp_i, gdp_j, gdp_i_base, gdp_j_base,
                             coefs=None):
    """
    Compute the demographic and GDP components of log(Flow_ij) under Model 2c.

    Parameters
    ----------
    z_i, z_j : mapping
        Reporter / partner values indexable by 'Z_1', 'Z_2', 'Z_3'
        (a pandas Series row or a plain dict).
    kaopen_j : float or None
        Partner KAOPEN. When missing (NaN/None) the KAOPEN interaction and
        level terms are omitted.
    gdp_i, gdp_j : float
        Reporter / partner GDP in the projection year.
    gdp_i_base, gdp_j_base : float
        Reporter / partner GDP in the baseline year.
    coefs : dict, optional
        Coefficients keyed like COEFS_2C; defaults to COEFS_2C.

    Returns
    -------
    (demo_effect, gdp_delta) : tuple of float
        demo_effect: sum of dZ_k terms, plus the dZ_k × kaopen_j interactions
        and the kaopen_j level term when kaopen_j is available.
        gdp_delta: log_gdp_product × change in log(GDP_i × GDP_j) vs baseline;
        0.0 unless all four GDP values are present and positive.
    """
    if coefs is None:
        coefs = COEFS_2C

    dZ_1 = z_i['Z_1'] - z_j['Z_1']
    dZ_2 = z_i['Z_2'] - z_j['Z_2']
    dZ_3 = z_i['Z_3'] - z_j['Z_3']

    # Demographic effect
    demo_effect = (coefs['dZ_1'] * dZ_1 +
                   coefs['dZ_2'] * dZ_2 +
                   coefs['dZ_3'] * dZ_3)

    # pd.isna also accepts None, where np.isnan would raise TypeError
    if not pd.isna(kaopen_j):
        demo_effect += (coefs['dZ_1_x_kaopen_j'] * dZ_1 * kaopen_j +
                        coefs['dZ_2_x_kaopen_j'] * dZ_2 * kaopen_j +
                        coefs['dZ_3_x_kaopen_j'] * dZ_3 * kaopen_j +
                        coefs['kaopen_j'] * kaopen_j)

    # GDP contribution change (log GDP product relative to baseline).
    # Missing (None/NaN) or non-positive GDP leaves the GDP channel at zero.
    gdp_delta = 0.0
    gdp_values = (gdp_i, gdp_j, gdp_i_base, gdp_j_base)
    if all(pd.notna(g) and g > 0 for g in gdp_values):
        # log(a*b) - log(c*d) written as a sum of log ratios
        gdp_delta = coefs['log_gdp_product'] * (
            np.log(gdp_i / gdp_i_base) + np.log(gdp_j / gdp_j_base))

    return demo_effect, gdp_delta


def compute_demographic_yield_pressure(z_vals):
    """
    Return the S1 demographic bond-yield pressure for one country.

    Computed as sum over k in 1..3 of S1_COEFS['Z_k'] * z_vals['Z_k'];
    `z_vals` may be any mapping with keys 'Z_1', 'Z_2', 'Z_3'.
    """
    pressure = 0
    for order in (1, 2, 3):
        key = f'Z_{order}'
        pressure += S1_COEFS[key] * z_vals[key]
    return pressure


def compute_global_clearing_rate(fp, gdp_by_year, year):
    """
    Compute the global Δr* that clears the GDP-weighted capital market
    for a given projection year.

    Clearing condition, with w_i = GDP_i / (total GDP of the countries that
    actually enter the sums), so that sum_i w_i = 1:
      sum_i [w_i * (demo_ca_i + delta * (yield_i - Δr*))] = 0
      => Δr* = PE_imbalance / delta + global_yield_effect
    where PE_imbalance = sum_i w_i * demo_ca_i and
    global_yield_effect = sum_i w_i * yield_i. Δr* is then clipped to
    ±MAX_DELTA_R; `residual` is the clearing condition evaluated at the
    clipped rate (zero up to rounding when the cap does not bind).

    Parameters
    ----------
    fp : DataFrame
        Panel with 'year', 'iso3' and Z_1..Z_3 columns.
    gdp_by_year : dict
        {year: {iso3: GDP}}; the 2024 map is used if `year` is absent.
    year : int
        Projection year whose Z values are used.

    Returns
    -------
    (delta_r, yield_map, pe_imbalance, delta_r_uncapped, residual)
        yield_map maps iso3 -> S1 demographic yield pressure for every
        country in the GDP map with complete Z data. The same 5-tuple shape
        is returned when no country carries positive GDP weight (all
        scalars 0.0), so callers can always unpack five values.
    """
    gdp_map = gdp_by_year[year] if year in gdp_by_year else gdp_by_year[2024]

    # Z values for this year (keep the first row if an iso3 is duplicated)
    year_data = fp[fp['year'] == year].drop_duplicates(subset='iso3').set_index('iso3')

    # Demographic CA betas from the followup multilateral model (Model 2)
    z_ca_betas = {'Z_1': 0.547, 'Z_2': -0.086, 'Z_3': 0.003}
    z_keys = ('Z_1', 'Z_2', 'Z_3')

    weighted_ca = 0.0     # sum_i GDP_i * demo_ca_i
    weighted_yield = 0.0  # sum_i GDP_i * demo_yield_i
    included_gdp = 0.0    # sum_i GDP_i over countries that enter the sums
    yield_map = {}

    for iso3, gdp in gdp_map.items():
        if iso3 not in year_data.index:
            continue
        row = year_data.loc[iso3]
        z_vals = {k: row.get(k, np.nan) for k in z_keys}
        # Any missing Z would turn every aggregate into NaN
        if any(pd.isna(v) for v in z_vals.values()):
            continue

        # Demographic CA contribution (PE) and yield pressure
        demo_ca = sum(z_ca_betas[k] * z_vals[k] for k in z_keys)
        demo_yield = compute_demographic_yield_pressure(z_vals)
        yield_map[iso3] = demo_yield

        # Countries without a usable weight still report their yield pressure
        if pd.isna(gdp) or gdp <= 0:
            continue
        weighted_ca += gdp * demo_ca
        weighted_yield += gdp * demo_yield
        included_gdp += gdp

    if included_gdp <= 0:
        return 0.0, yield_map, 0.0, 0.0, 0.0

    # Normalise over included countries only; dividing by the GDP of all
    # countries (including those skipped) would shrink Δr* by their share.
    pe_imbalance = weighted_ca / included_gdp
    global_yield_effect = weighted_yield / included_gdp

    # Clearing rate
    delta_r_uncapped = pe_imbalance / DELTA_RATE_MULTILATERAL + global_yield_effect
    delta_r = np.clip(delta_r_uncapped, -MAX_DELTA_R, MAX_DELTA_R)
    residual = pe_imbalance + DELTA_RATE_MULTILATERAL * (global_yield_effect - delta_r)

    return delta_r, yield_map, pe_imbalance, delta_r_uncapped, residual


def main():
    """
    Run Phase 5b end to end and write the CSV outputs to OUTPUT_DIR.

    Steps: print GDP diagnostics, compute global clearing rates per
    projection year, build PE and GE bilateral projections for every ordered
    pair of key countries, compare against the 2024 baseline, print summary
    tables, and save results.

    Returns
    -------
    (proj_df, clearing_df, net_summary)
        proj_df: one row per (year, reporter, partner) with PE/GE components
        and changes vs 2024.
        clearing_df: clearing-rate diagnostics per projection year.
        net_summary: per-country 2050 mean outward minus mean received
        change, PE vs GE.

    Raises
    ------
    RuntimeError
        If no bilateral projection could be computed (no key country has Z
        data in the projection years).
    """
    base_year = 2024  # baseline for GDP deltas and all %-change comparisons

    print("=" * 70)
    print("PHASE 5b: ENHANCED BILATERAL PROJECTIONS (WEO GDP + GE CLEARING)")
    print("=" * 70)

    fp, gdp_by_year, kaopen_map, name_map = load_data()
    gdp_base = gdp_by_year[base_year]

    # Show GDP evolution for key countries
    print("\nGDP evolution (WEO, billions USD):")
    print(f"  {'Country':<6}  {'2024':>10}  {'2030':>10}  {'Δ%':>8}")
    for iso in ['USA', 'CHN', 'IND', 'KOR', 'DEU', 'JPN', 'NGA', 'ARE']:
        g24 = gdp_base.get(iso, 0)
        g30 = gdp_by_year[2030].get(iso, 0)
        pct = (g30 / g24 - 1) * 100 if g24 > 0 else 0
        print(f"  {iso:<6}  {g24:>10.1f}  {g30:>10.1f}  {pct:>+7.1f}%")

    # Filter to countries with Z data in at least one projection year
    available = set(fp.loc[fp['year'].isin(PROJ_YEARS) & fp['Z_1'].notna(), 'iso3'])
    key_countries = [c for c in KEY_COUNTRIES if c in available]
    print(f"\nKey countries with projection data: {len(key_countries)}")

    # ===================================================================
    # Step 1: Compute global clearing rates for each projection year
    # ===================================================================
    print(f"\n{'=' * 70}")
    print("GLOBAL CLEARING RATES")
    print(f"{'=' * 70}")

    clearing_data = []
    yield_maps = {}
    print(f"  {'Year':>6}  {'PE Imbalance':>14}  {'Δr* uncapped':>14}  {'Δr* (capped)':>14}  {'Residual':>10}")
    for year in PROJ_YEARS:
        delta_r, yield_map, pe_imb, delta_r_unc, residual = compute_global_clearing_rate(
            fp, gdp_by_year, year)
        yield_maps[year] = yield_map
        clearing_data.append({
            'year': year,
            'pe_imbalance': pe_imb,
            'delta_r_uncapped': delta_r_unc,
            'delta_r_world': delta_r,
            'residual': residual,
        })
        cap_note = " [CAPPED]" if abs(delta_r_unc) > MAX_DELTA_R else ""
        print(f"  {year:>6}  {pe_imb:>14.4f}  {delta_r_unc:>14.3f}  {delta_r:>14.3f}  {residual:>10.4f}{cap_note}")

    clearing_df = pd.DataFrame(clearing_data)
    delta_r_map = dict(zip(clearing_df['year'], clearing_df['delta_r_world']))

    # ===================================================================
    # Step 2: Compute bilateral projections (PE + GE)
    # ===================================================================
    print(f"\n{'=' * 70}")
    print("BILATERAL PROJECTIONS")
    print(f"{'=' * 70}")

    results = []
    for year in PROJ_YEARS:
        year_data = fp[fp['year'] == year].drop_duplicates(subset='iso3').set_index('iso3')
        gdp_map = gdp_by_year.get(year, gdp_base)
        yield_map = yield_maps.get(year, {})
        delta_r_world = delta_r_map.get(year, 0.0)

        # GE overlay. A uniform world-rate shift Δr* leaves every bilateral
        # yield differential (yield_i - yield_j) unchanged, so the overlay
        # acts only on the LEVEL of all flows, identically for every pair:
        #   Δlog(Flow_ij)_GE_level = BILATERAL_RATE_COEF × (-Δr*_world)
        # NOTE(review): with BILATERAL_RATE_COEF < 0 this equals +0.161·Δr*,
        # i.e. a higher world rate RAISES projected flows — confirm the sign.
        # NOTE(review): the baseline year gets 0 rather than its own Δr*, so
        # delta_ge vs 2024 reflects the level Δr*_year rather than
        # Δr*_year - Δr*_2024 — confirm this is intended.
        ge_level_adjustment = (BILATERAL_RATE_COEF * (-delta_r_world)
                               if year != base_year else 0)

        for i_iso in key_countries:
            if i_iso not in year_data.index:
                continue
            z_i = year_data.loc[i_iso]
            gdp_i = gdp_map.get(i_iso, 0)
            gdp_i_base = gdp_base.get(i_iso, 0)
            yield_i = yield_map.get(i_iso, 0)

            for j_iso in key_countries:
                if j_iso == i_iso or j_iso not in year_data.index:
                    continue
                z_j = year_data.loc[j_iso]
                kaopen_j = kaopen_map.get(j_iso, np.nan)

                demo_effect, gdp_delta = compute_bilateral_effect(
                    z_i, z_j, kaopen_j,
                    gdp_i, gdp_map.get(j_iso, 0),
                    gdp_i_base, gdp_base.get(j_iso, 0))

                # Total PE effect (demographics + GDP evolution), then GE overlay
                pe_total = demo_effect + gdp_delta
                ge_total = pe_total + ge_level_adjustment

                results.append({
                    'year': year,
                    'reporter': i_iso,
                    'partner': j_iso,
                    'reporter_name': name_map.get(i_iso, i_iso),
                    'partner_name': name_map.get(j_iso, j_iso),
                    'dZ_1': z_i['Z_1'] - z_j['Z_1'],
                    'kaopen_j': kaopen_j,
                    'demo_effect': demo_effect,
                    'gdp_delta': gdp_delta,
                    'pe_total': pe_total,
                    'bilateral_yield_diff': yield_i - yield_map.get(j_iso, 0),
                    'delta_r_world': delta_r_world,
                    'ge_level_adj': ge_level_adjustment,
                    'ge_total': ge_total,
                })

    proj_df = pd.DataFrame(results)
    if proj_df.empty:
        raise RuntimeError(
            "No bilateral projections computed: no key country has Z data "
            "in the projection years")
    print(f"Computed {len(proj_df):,} bilateral projections")

    # Compute changes relative to 2024 baseline
    baseline = proj_df[proj_df['year'] == base_year][
        ['reporter', 'partner', 'pe_total', 'ge_total', 'demo_effect']
    ].rename(columns={'pe_total': 'pe_2024', 'ge_total': 'ge_2024', 'demo_effect': 'demo_2024'})
    proj_df = proj_df.merge(baseline, on=['reporter', 'partner'], how='left')
    proj_df['delta_pe'] = proj_df['pe_total'] - proj_df['pe_2024']
    proj_df['delta_ge'] = proj_df['ge_total'] - proj_df['ge_2024']
    proj_df['delta_demo'] = proj_df['demo_effect'] - proj_df['demo_2024']
    proj_df['pct_change_pe'] = (np.exp(proj_df['delta_pe']) - 1) * 100
    proj_df['pct_change_ge'] = (np.exp(proj_df['delta_ge']) - 1) * 100
    proj_df['pct_change_demo'] = (np.exp(proj_df['delta_demo']) - 1) * 100

    # ===================================================================
    # Top bilateral shifts: PE vs GE comparison
    # ===================================================================
    proj_2050 = proj_df[proj_df['year'] == 2050].copy()

    print(f"\n{'=' * 70}")
    print("TOP 15 BILATERAL SHIFTS BY 2050: PE vs GE")
    print(f"{'=' * 70}")
    print(f"  {'Reporter':<5} {'→':>1} {'Partner':<5}  {'Demo %Δ':>9}  {'GDP %Δ':>8}  {'PE %Δ':>8}  {'GE %Δ':>8}  {'GE damp':>8}")
    top = proj_2050.nlargest(15, 'delta_pe')
    for _, row in top.iterrows():
        gdp_pct = (np.exp(row['gdp_delta']) - 1) * 100
        ge_damp = row['pct_change_ge'] - row['pct_change_pe']
        print(f"  {row['reporter']:<5} → {row['partner']:<5}  {row['pct_change_demo']:>+8.1f}%  "
              f"{gdp_pct:>+7.1f}%  {row['pct_change_pe']:>+7.1f}%  {row['pct_change_ge']:>+7.1f}%  "
              f"{ge_damp:>+7.1f}pp")

    # ===================================================================
    # Time paths for selected pairs: PE vs GE
    # ===================================================================
    interesting_pairs = [
        ('KOR', 'IND'), ('KOR', 'NGA'), ('KOR', 'PHL'),
        ('CHN', 'IND'), ('CHN', 'NGA'),
        ('DEU', 'IND'), ('DEU', 'NGA'),
        ('JPN', 'IND'), ('JPN', 'NGA'),
        ('USA', 'IND'), ('USA', 'MEX'),
    ]

    print(f"\n{'=' * 70}")
    print("TIME PATHS: PE vs GE (% change vs 2024)")
    print(f"{'=' * 70}")
    print(f"  {'Pair':<12}  {'':>4}  {'2030':>8}  {'2040':>8}  {'2050':>8}")
    for i_iso, j_iso in interesting_pairs:
        pair_data = proj_df[(proj_df['reporter'] == i_iso) & (proj_df['partner'] == j_iso)]
        if len(pair_data) < 2:
            continue
        future = pair_data[pair_data['year'] != base_year]
        pe_vals = dict(zip(future['year'].astype(int), future['pct_change_pe']))
        ge_vals = dict(zip(future['year'].astype(int), future['pct_change_ge']))
        pe_str = "  ".join(f"{pe_vals.get(y, 0):>+7.1f}%" for y in [2030, 2040, 2050])
        ge_str = "  ".join(f"{ge_vals.get(y, 0):>+7.1f}%" for y in [2030, 2040, 2050])
        print(f"  {i_iso}→{j_iso:<6}  {'PE':>4}  {pe_str}")
        print(f"  {'':>12}  {'GE':>4}  {ge_str}")

    # ===================================================================
    # Decomposition: Demographics vs GDP vs GE for 2050
    # ===================================================================
    print(f"\n{'=' * 70}")
    print("DECOMPOSITION OF 2050 PROJECTED CHANGES (avg across partners)")
    print(f"{'=' * 70}")
    print(f"  {'Country':<6}  {'Demo':>8}  {'GDP':>8}  {'GE adj':>8}  {'Total GE':>8}")

    country_decomp = []
    for iso in key_countries:
        out = proj_2050[proj_2050['reporter'] == iso]
        if out.empty:
            continue
        avg_demo = out['delta_demo'].mean()
        avg_gdp = out['gdp_delta'].mean()
        avg_ge_adj = out['ge_level_adj'].mean()
        avg_total = out['delta_ge'].mean()
        country_decomp.append({
            'iso3': iso,
            'avg_demo_delta': avg_demo,
            'avg_gdp_delta': avg_gdp,
            'avg_ge_adjustment': avg_ge_adj,
            'avg_total_ge': avg_total,
        })
        print(f"  {iso:<6}  {avg_demo:>+8.3f}  {avg_gdp:>+8.3f}  {avg_ge_adj:>+8.3f}  {avg_total:>+8.3f}")

    # ===================================================================
    # Net reallocation pressure (PE vs GE)
    # ===================================================================
    print(f"\n{'=' * 70}")
    print("NET REALLOCATION PRESSURE: PE vs GE (2050 vs 2024)")
    print(f"{'=' * 70}")

    outward_pe = proj_2050.groupby('reporter')['delta_pe'].mean()
    received_pe = proj_2050.groupby('partner')['delta_pe'].mean()
    outward_ge = proj_2050.groupby('reporter')['delta_ge'].mean()
    received_ge = proj_2050.groupby('partner')['delta_ge'].mean()

    net_summary = pd.DataFrame({
        'net_pe': outward_pe - received_pe,
        'net_ge': outward_ge - received_ge,
    })
    net_summary['ge_damping'] = net_summary['net_ge'] - net_summary['net_pe']
    net_summary = net_summary.sort_values('net_ge', ascending=False)

    print(f"  {'Country':<6}  {'Net PE':>8}  {'Net GE':>8}  {'GE damp':>8}")
    for iso, row in net_summary.iterrows():
        print(f"  {iso:<6}  {row['net_pe']:>+8.3f}  {row['net_ge']:>+8.3f}  {row['ge_damping']:>+8.3f}")

    # ===================================================================
    # Save all results (create the output directory if needed)
    # ===================================================================
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    proj_df.to_csv(OUTPUT_DIR / "bilateral_projections_ge.csv", index=False)
    clearing_df.to_csv(OUTPUT_DIR / "ge_bilateral_clearing_rates.csv", index=False)
    net_summary.index.name = 'iso3'
    net_summary.to_csv(OUTPUT_DIR / "projection_summary_ge.csv")
    pd.DataFrame(country_decomp).to_csv(OUTPUT_DIR / "projection_decomposition_2050.csv", index=False)

    print("\nSaved: bilateral_projections_ge.csv")
    print("Saved: ge_bilateral_clearing_rates.csv")
    print("Saved: projection_summary_ge.csv")
    print("Saved: projection_decomposition_2050.csv")

    return proj_df, clearing_df, net_summary


if __name__ == "__main__":
    proj_df, clearing_df, net_summary = main()
